"""
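Compute cosine similarity between patent pairs.
The similarity matrix is computed in chunks of --chunk_size patents.
Please specify the input and output paths on the command line (see the example code below).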
Shaoyu Liu (modified by Zihao Li, 06/2024)

patent: n_patent X 768

Generate:
1. csv file like sim_{year}_{i}_top{top_k}.csv
                 sim_{year}_{i}_cutoff{cutoff}.csv
                 sim_{year}_{i}_sample{m_sample}.csv
2. a dictionary of patent matching patent_id and index. 
    To generate patent dictionary, please call append_patent_id function.
    patent_dict.csv (cover from year 1976 to 2015 patents)

example code:
    cd /Volumes/Zihao_SSD2/PatentsView && \
    python3 code/compute_similarity.py \
        --path_patent patentsberta/patentsberta_embedding_matrix.npy \
        --path_patent_dict patentsberta/patent_dict.csv \
        --save_dir patentsberta/patent_sim_combined \
        --chunk_size 1000

example code to restrict to KPSS patents:
    cd /Volumes/Zihao_SSD2/PatentsView && \
    python3 code/compute_similarity.py \
        --path_patent patentsberta/patentsberta_embedding_matrix_kpss.npy \
        --path_patent_dict patentsberta/patent_dict_kpss.csv \
        --save_dir patentsberta/patent_sim_kpss \
        --chunk_size 1000

"""

import time
import pandas as pd
import os
import argparse
import numpy as np
import glob
from typing import List, Dict
import json
import logging
from scipy.sparse import load_npz, save_npz
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Dict

# Switch the process working directory at import time, so every relative path
# used afterwards resolves against this folder. This raises if the directory
# does not exist (e.g. the external volume is not mounted).
# NOTE(review): this volume is "Zihao_SSD2", while WORK_DIR in the __main__
# section points at "Zihao_SSD" -- confirm which one is intended.
os.chdir(r'/Volumes/Zihao_SSD2/PatentsView/')

# A `global` statement at module scope has no effect; WORK_DIR is actually
# assigned in the `if __name__ == "__main__":` section at the bottom of the file.
global WORK_DIR

def create_year_idx(path_patent_dict, min_year=1976, max_year=2015) -> dict:
    """
    Create the row ranges used to subset the patent matrix for each year's patents.

    The patent dictionary is read from ``WORK_DIR + path_patent_dict``
    (``WORK_DIR`` is a module-level global that must be set before calling).

    Args:
        path_patent_dict: patent_dict.csv [patent_id, idx, patent_year],
            given relative to WORK_DIR
        min_year: first year (inclusive) to build a range for
        max_year: last year (inclusive) to build a range for

    Returns:
        year_idx: ``{year: {'patent_idx': (min_idx, max_idx)}}`` where
            ``min_idx`` is the smallest ``idx`` of the year and ``max_idx`` is
            the largest ``idx`` plus one, i.e. a half-open slice range.
            Years without any patent in the dictionary are left out.

    Raises:
        ValueError: if the largest ``idx`` is not ``len(patent_dict) - 1``.
    """
    log = logging.getLogger(__name__)

    patent_dict = pd.read_csv(
        WORK_DIR + path_patent_dict,
        usecols=['patent_id', 'patent_year', 'idx'],
    )
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if patent_dict['idx'].max() != len(patent_dict) - 1:
        raise ValueError('multiple index error')

    # One pass over the frame instead of filtering it twice per year.
    spans = patent_dict.groupby('patent_year')['idx'].agg(['min', 'max', 'count'])

    year_idx = {}
    for yr in range(min_year, max_year + 1):
        if yr not in spans.index:
            # min()/max() of an empty selection would be NaN, which cannot be
            # used as a slice bound later on, so such years are skipped.
            log.warning("no patents found for year %s; year skipped", yr)
            continue

        min_idx = int(spans.at[yr, 'min'])
        max_idx = int(spans.at[yr, 'max']) + 1
        n_rows = int(spans.at[yr, 'count'])
        if max_idx - min_idx != n_rows:
            # The (min, max) range is only a valid slice of the year when the
            # year's rows occupy one contiguous block of idx values.
            log.warning(
                "year %s: idx range [%s, %s) spans %s rows but the year has %s patents",
                yr, min_idx, max_idx, max_idx - min_idx, n_rows,
            )

        year_idx[yr] = {'patent_idx': (min_idx, max_idx)}
    return year_idx


def compute_and_save_similarities(
    patent: np.ndarray,
    year_idx: dict,
    year: int,
    save_dir: str,
    chunk_size: int,
    top_k: int,
    cutoff: float,
    m_sample: int,
    start_idx: int,
) -> None:
    """
    Compute cosine similarity between patents from a given year and all rows of
    ``patent`` that come before that year's first row (``patent[:min_idx]``).
    Save results using three different selection criteria: top-k, cutoff, and random sampling.

    Each chunk of ``chunk_size`` patents produces one CSV per criterion, written to
    ``{save_dir}/sim_{year}_topk``, ``{save_dir}/sim_{year}_cutoff`` and
    ``{save_dir}/sim_{year}_sample`` (the folders are created if missing).

    Args:
        patent: a numpy array of patent embeddings
        year_idx: dictionary of {year: {'patent_idx': (start_idx, end_idx)}}
        year: year of patent to process
        save_dir: directory to save processed cosine matrices
        chunk_size: size of the chunk to process
        top_k: top k most similar patents to keep; capped at the number of
            earlier patents. The k columns are the k largest scores but are not
            sorted by score.
        cutoff: cutoff similarity value; pairs with similarity > cutoff are kept
        m_sample: m patents to randomly sample for each patent, without
            replacement within a patent; capped at the number of earlier patents
        start_idx: index of the first chunk to process; earlier chunks are skipped
    Returns:
        None
    Raises:
        ValueError: if year, cutoff or chunk_size is out of range.
    """
    if year <= 1976 or year > 2015:
        raise ValueError("year must be integer between 1977 and 2015.")
    if cutoff >= 1 or cutoff <= 0:
        raise ValueError("cosine similarity cutoff must be between 0 and 1.")
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer.")

    # Resolve the logger here rather than relying on a global that only exists
    # when the file is run as a script.
    log = logging.getLogger(__name__)

    min_idx, max_idx = year_idx[year]["patent_idx"]

    year_patent = patent[min_idx:max_idx, :]  # get patents in the given year
    prev_patent = patent[:min_idx, :]  # get patents published before the given year

    print(f'year_patent.shape: {year_patent.shape}')
    print(f'prev_patent.shape: {prev_patent.shape}')
    log.info("year_patent.shape: %s", year_patent.shape)
    log.info("prev_patent.shape: %s", prev_patent.shape)
    n_patent = int(year_patent.shape[0])

    # Save directories for different selection criteria
    save_folder_topk = os.path.join(save_dir, f"sim_{year}_topk")
    save_folder_cutoff = os.path.join(save_dir, f"sim_{year}_cutoff")
    save_folder_sample = os.path.join(save_dir, f"sim_{year}_sample")
    for folder in (save_folder_topk, save_folder_cutoff, save_folder_sample):
        os.makedirs(folder, exist_ok=True)

    print(f"number of patents in this year = {n_patent}, chunk_size = {chunk_size}")
    log.info("number of patents in this year = %s, chunk_size = %s", n_patent, chunk_size)

    for i, chunk in enumerate(range(0, n_patent, chunk_size)):
        if i < start_idx:
            continue

        print(f"processing {i}th chunk")
        log.info("processing %sth chunk", i)
        sim_mat = cosine_similarity(
            year_patent[chunk : chunk + chunk_size, :], prev_patent, dense_output=True
        )
        num_rows, num_cols = sim_mat.shape

        # Top-k similar patents. k is capped so argpartition does not fail when
        # fewer than top_k earlier patents exist.
        k = min(top_k, num_cols)
        top_k_idx = np.argpartition(sim_mat, -k, axis=1)[:, : -k - 1 : -1]
        top_k_sim_score = np.take_along_axis(sim_mat, top_k_idx, axis=1)

        # Cosine similarity > cutoff; the mask is evaluated once per row.
        indices_by_row = []
        values_by_row = []
        for row in sim_mat:
            hits = np.flatnonzero(row > cutoff)
            indices_by_row.append(hits.tolist())
            values_by_row.append(row[hits].tolist())

        # Random sample of m earlier patents, drawn independently for every row.
        random_indices = _sample_columns_per_row(num_rows, num_cols, m_sample)
        sampled_matrix = np.take_along_axis(sim_mat, random_indices, axis=1)

        # Row numbers (in the full patent matrix) of the patents in this chunk;
        # the last chunk of a year may be shorter than chunk_size.
        chunk_start = min_idx + chunk
        patent_index = list(range(chunk_start, min(chunk_start + chunk_size, max_idx)))

        del sim_mat  # release the dense similarity matrix before writing files

        # Save results for different selection criteria
        save_topk_results(patent_index, top_k_idx, top_k_sim_score, year, i, top_k, save_folder_topk)
        save_cutoff_results(patent_index, indices_by_row, values_by_row, year, i, cutoff, save_folder_cutoff)
        save_sample_results(patent_index, random_indices, sampled_matrix, year, i, m_sample, save_folder_sample)


def _sample_columns_per_row(num_rows: int, num_cols: int, m_sample: int) -> np.ndarray:
    """Draw ``min(m_sample, num_cols)`` distinct column indices independently for each row.

    Sampling row by row (instead of one draw without replacement for the whole
    chunk) means the draw cannot fail just because ``num_rows * m_sample``
    exceeds ``num_cols``, and different rows may sample the same column.
    """
    n_draw = min(m_sample, num_cols)
    rng = np.random.default_rng()
    sampled = np.empty((num_rows, n_draw), dtype=np.int64)
    for row in range(num_rows):
        sampled[row] = rng.choice(num_cols, size=n_draw, replace=False)
    return sampled


def save_topk_results(patent_index, top_k_idx, top_k_sim_score, year, i, top_k, save_folder_topk):
    """Write one chunk's top-k matches to ``sim_{year}_{i}_top{top_k}.csv``.

    Args:
        patent_index: row labels, one per row of the two matrices
        top_k_idx: 2-D array of matched column indices
        top_k_sim_score: 2-D array of similarity scores, same shape as top_k_idx
        year, i, top_k: only used to build the file name
        save_folder_topk: folder the CSV is written into

    The CSV has columns patent_idx, cited_patent_idx, sim_score in long format:
    all rows for the first matrix column come first, then the second, and so on.
    """
    def _to_long(matrix, value_name):
        # wide (one column per rank) -> long (one row per patent/rank pair)
        wide = pd.DataFrame(data=matrix, index=patent_index).reset_index()
        wide = wide.rename(columns={'index': 'patent_idx'})
        return pd.melt(wide, id_vars='patent_idx', value_name=value_name)

    matches = _to_long(top_k_idx, 'cited_patent_idx').drop(columns=['variable'])
    scores = _to_long(top_k_sim_score, 'sim_score').drop(columns=['variable', 'patent_idx'])

    # Both long frames share the same row order, so a column-wise concat lines them up.
    combined = pd.concat([matches, scores], axis=1)
    target = os.path.join(save_folder_topk, f"sim_{year}_{i}_top{top_k}.csv")
    combined.to_csv(target, index=False)

def save_cutoff_results(patent_index, indices_by_row, values_by_row, year, i, cutoff, save_folder_cutoff):
    """Write one chunk's above-cutoff matches to ``sim_{year}_{i}_cutoff{NN}.csv``.

    ``NN`` is the cutoff expressed in percent (0.65 -> ``65``, 0.7 -> ``70``,
    0.655 -> ``65.5``), so files produced with different cutoffs no longer
    share the same hard-coded ``cutoff65`` name.

    Args:
        patent_index: row labels, one per patent in the chunk
        indices_by_row: for each patent, the list of matched column indices
        values_by_row: for each patent, the list of matching similarity scores
            (same lengths as the lists in indices_by_row)
        year, i: only used to build the file name
        cutoff: similarity cutoff that was applied; used to build the file name
        save_folder_cutoff: folder the CSV is written into

    The CSV has columns patent_idx, cited_patent_idx, sim_score with one row per
    (patent, match) pair. A patent with an empty match list gets a single row
    with empty cited_patent_idx and sim_score.
    """
    df_idx = pd.DataFrame({'patent_idx': patent_index, 'cited_patent_idx': indices_by_row})
    df_idx = df_idx.explode('cited_patent_idx').reset_index(drop=True)

    df_val = pd.DataFrame({'patent_idx': patent_index, 'sim_score': values_by_row})
    df_val = df_val.explode('sim_score').reset_index(drop=True)

    # Both frames were exploded from lists of equal length, so rows line up.
    df = pd.concat([df_idx, df_val['sim_score']], axis=1)

    # ':g' drops float noise (e.g. 0.57 * 100 == 56.99999999999999 -> '57').
    cutoff_label = f"{cutoff * 100:g}"
    output_path = os.path.join(save_folder_cutoff, f"sim_{year}_{i}_cutoff{cutoff_label}.csv")
    df.to_csv(output_path, index=False)

def save_sample_results(patent_index, random_indices, sampled_matrix, year, i, m_sample, save_folder_sample):
    """Write one chunk's randomly sampled pairs to ``sim_{year}_{i}_sample{m_sample}.csv``.

    Args:
        patent_index: row labels, one per row of the two matrices
        random_indices: 2-D array of sampled column indices
        sampled_matrix: 2-D array of similarity scores, same shape as random_indices
        year, i, m_sample: only used to build the file name
        save_folder_sample: folder the CSV is written into

    The CSV has columns patent_idx, cited_patent_idx, sim_score in long format
    (column-by-column order of the input matrices).
    """
    # Long table of (patent_idx, cited_patent_idx)
    pairs = (
        pd.DataFrame(data=random_indices, index=patent_index)
        .reset_index()
        .rename(columns={'index': 'patent_idx'})
    )
    pairs = pd.melt(pairs, id_vars='patent_idx', value_name='cited_patent_idx')
    pairs = pairs.drop(columns=['variable'])

    # Long table of scores, melted the same way so the row order matches `pairs`
    score_col = (
        pd.DataFrame(data=sampled_matrix, index=patent_index)
        .reset_index()
        .rename(columns={'index': 'patent_idx'})
    )
    score_col = pd.melt(score_col, id_vars='patent_idx', value_name='sim_score')
    score_col = score_col.drop(columns=['variable', 'patent_idx'])

    result = pd.concat([pairs, score_col], axis=1)
    file_name = f"sim_{year}_{i}_sample{m_sample}.csv"
    result.to_csv(os.path.join(save_folder_sample, file_name), index=False)


def main(patent_dict_path, patent_matrix_path, save_dir, chunk_size, top_k, cutoff, m_sample, start_idx,
         first_year=1981, last_year=2015):
    """Load the embeddings and write similarity files for every year in the range.

    Args:
        patent_dict_path: path to patent_dict.csv, passed to create_year_idx
        patent_matrix_path: path to the .npy patent embedding matrix
        save_dir: directory that receives the per-year output folders
        chunk_size: number of patents per similarity chunk
        top_k: number of most similar patents to keep per patent
        cutoff: similarity cutoff
        m_sample: number of randomly sampled patents per patent
        start_idx: first chunk to process in ``first_year``. It is used to
            resume an interrupted run, so it only applies to the first year;
            every later year starts from chunk 0.
        first_year: first year (inclusive) to process
        last_year: last year (inclusive) to process
    """
    print('Loading embedding...')
    year_idx = create_year_idx(patent_dict_path, min_year=1976, max_year=2015)
    patent = np.load(patent_matrix_path)
    logging.warning(f"loaded patent embedding matrix of shape {patent.shape}, this may take large memory.")

    for year in range(first_year, last_year + 1):
        print(f'Year {year}')
        # Passing start_idx to every year would silently skip the first chunks
        # of all later years as well.
        first_chunk = start_idx if year == first_year else 0
        compute_and_save_similarities(
            patent,
            year_idx,
            year,
            save_dir,
            chunk_size,
            top_k,
            cutoff,
            m_sample,
            first_chunk,
        )

if __name__ == "__main__":
    # Command-line entry point. The three path arguments are required: without
    # them None would be passed on and fail later with an unrelated TypeError.
    ap = argparse.ArgumentParser()
    ap.add_argument('--path_patent', required=True, help='specify path to patent embedding npy')
    ap.add_argument('--path_patent_dict', required=True,
                    help='path to patent_dict.csv, relative to WORK_DIR')
    ap.add_argument('--save_dir', required=True,
                    help='directory that receives the similarity csv folders')
    ap.add_argument('--chunk_size', type=int, default=1000,
                    help='number of patents per similarity chunk')
    ap.add_argument('--top_k', type=int, default=101,
                    help='number of most similar patents to keep')
    ap.add_argument('--cutoff', type=float, default=0.65,
                    help='keep pairs with cosine similarity above this value')
    ap.add_argument('--m_sample', type=int, default=101,
                    help='number of randomly sampled patents')
    ap.add_argument('--start_idx', type=int, default=0,
                    help='index of the first chunk to process')
    args = ap.parse_args()

    # NOTE(review): this volume is "Zihao_SSD", while the os.chdir call at the
    # top of the file uses "Zihao_SSD2" -- confirm which one is intended.
    WORK_DIR = r'/Volumes/Zihao_SSD/PatentsView/'
    # change log name to patentsberta_kpss.log when doing robustness check of k
    logging.basicConfig(
        filename=os.path.join(WORK_DIR, 'logs', 'patentsberta_combined.log'),
        format='%(asctime)s:%(levelname)s:%(message)s',
        level=logging.INFO)
    logger = logging.getLogger(__name__)

    main(args.path_patent_dict, args.path_patent, args.save_dir, args.chunk_size, args.top_k, args.cutoff, args.m_sample, args.start_idx)